from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange

import torch
import torch.nn as nn

import numpy as np

from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attention import attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate
from .token_refiner import SingleTokenRefiner
from ...enhance_a_video.enhance import get_feta_scores
from ...enhance_a_video.globals import is_enhance_enabled_single, is_enhance_enabled_double, set_num_frames
from .norm_layers import RMSNorm

from contextlib import contextmanager

@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
    """Temporarily patch parameter/buffer registration so new modules land on `device`."""

    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer
    
    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)
            
    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper
    
    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}
    
    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
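
# A minimal usage sketch: modules created inside the context get meta-device
# parameters (no real allocation); real weights can be attached afterwards with
# load_state_dict(..., assign=True). `real_state_dict` below is hypothetical.
#
#     with init_weights_on_device():
#         layer = torch.nn.Linear(8, 8)
#     assert layer.weight.device.type == "meta"
#     layer.load_state_dict(real_state_dict, assign=True)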

class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for
    text and image/video. See more details (SD3): https://arxiv.org/abs/2403.03206
                                        (Flux.1): https://github.com/black-forest-labs/flux
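
    Expected shapes (inferred from the rearrange patterns in forward):
        img: (B, L_img, hidden_size); txt: (B, L_txt, hidden_size); vec: (B, hidden_size).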
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.attention_mode = attention_mode

        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.img_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.img_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.txt_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.txt_norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_mask: Optional[torch.Tensor] = None,
        upcast_rope: bool = True,
        token_replace_vec: Optional[torch.Tensor] = None,
        first_frame_token_num: Optional[int] = None,
        condition_type: Optional[str] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if condition_type == "token_replace":
            img_mod1, token_replace_img_mod1 = self.img_mod(
                vec, condition_type=condition_type, token_replace_vec=token_replace_vec
            )
            (img_mod1_shift,
             img_mod1_scale,
             img_mod1_gate,
             img_mod2_shift,
             img_mod2_scale,
             img_mod2_gate) = img_mod1.chunk(6, dim=-1)
            (tr_img_mod1_shift,
             tr_img_mod1_scale,
             tr_img_mod1_gate,
             tr_img_mod2_shift,
             tr_img_mod2_scale,
             tr_img_mod2_gate) = token_replace_img_mod1.chunk(6, dim=-1)
        else:
            (
                img_mod1_shift,
                img_mod1_scale,
                img_mod1_gate,
                img_mod2_shift,
                img_mod2_scale,
                img_mod2_gate,
            ) = self.img_mod(vec).chunk(6, dim=-1)

        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.txt_mod(vec).chunk(6, dim=-1)

        # Prepare image for attention.
        img_modulated = self.img_norm1(img)
        if condition_type == "token_replace":
            img_modulated = modulate(
                img_modulated, shift=img_mod1_shift, scale=img_mod1_scale, condition_type=condition_type,
                tr_shift=tr_img_mod1_shift, tr_scale=tr_img_mod1_scale,
                first_frame_token_num=first_frame_token_num
            )
        else:
            img_modulated = modulate(
                img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
            )
        img_qkv = self.img_attn_qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
        # Apply QK-Norm if needed
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply RoPE if needed.
        if freqs_cis is not None:
            img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, upcast=upcast_rope)
            
        # Prepare txt for attention.
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(
            txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
        )
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
        )
                
        # Apply QK-Norm if needed.
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        if is_enhance_enabled_double():
            feta_scores = get_feta_scores(img_q, img_k)

        # Run actual attention.
        q = torch.cat((img_q, txt_q), dim=1)
        k = torch.cat((img_k, txt_k), dim=1)
        v = torch.cat((img_v, txt_v), dim=1)

        attn = attention(
            q,
            k,
            v,
            heads = self.heads_num,
            mode=self.attention_mode,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            batch_size=img_k.shape[0],
            attn_mask=attn_mask
        )

        img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        if is_enhance_enabled_double():
            img_attn *= feta_scores

        # Calculate the img blocks.
        if condition_type == "token_replace":
            img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate, condition_type=condition_type,
                                   tr_gate=tr_img_mod1_gate, first_frame_token_num=first_frame_token_num)
            img = img + apply_gate(
                self.img_mlp(
                    modulate(
                        self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale, condition_type=condition_type,
                        tr_shift=tr_img_mod2_shift, tr_scale=tr_img_mod2_scale, first_frame_token_num=first_frame_token_num
                    )
                ),
                gate=img_mod2_gate, condition_type=condition_type,
                tr_gate=tr_img_mod2_gate, first_frame_token_num=first_frame_token_num
            )
        else:
            img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
            img = img + apply_gate(
                self.img_mlp(
                    modulate(
                        self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
                    )
                ),
                gate=img_mod2_gate,
            )

        # Calculate the txt blocks.
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        txt = txt + apply_gate(
            self.txt_mlp(
                modulate(
                    self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
                )
            ),
            gate=txt_mod2_gate,
        )

        return img, txt


class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    Also refer to (SD3): https://arxiv.org/abs/2403.03206
                  (Flux.1): https://github.com/black-forest-labs/flux
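
    Expected shapes (inferred from the forward pass):
        x: (B, L_img + L_txt, hidden_size); vec: (B, hidden_size).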
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: Optional[float] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.attention_mode = attention_mode

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim ** -0.5

        # qkv and mlp_in
        self.linear1 = nn.Linear(
            hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
        )
        # proj and mlp_out
        self.linear2 = nn.Linear(
            hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
        )

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )

        self.pre_norm = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        self.mlp_act = get_activation_layer(mlp_act_type)()
        self.modulation = ModulateDiT(
            hidden_size,
            factor=3,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_mask: Optional[torch.Tensor] = None,
        upcast_rope: bool = True,
        token_replace_vec: Optional[torch.Tensor] = None,
        first_frame_token_num: Optional[int] = None,
        condition_type: Optional[str] = None,
        stg_mode: Optional[str] = None,
    ) -> torch.Tensor:
        if condition_type == "token_replace":
            mod, tr_mod = self.modulation(vec,
                                          condition_type=condition_type,
                                          token_replace_vec=token_replace_vec)
            (mod_shift,
             mod_scale,
             mod_gate) = mod.chunk(3, dim=-1)
            (tr_mod_shift,
             tr_mod_scale,
             tr_mod_gate) = tr_mod.chunk(3, dim=-1)
        else:
            mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
        if condition_type == "token_replace":
            x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale, condition_type=condition_type,
                             tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, first_frame_token_num=first_frame_token_num)
        else:
            x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

        # Apply QK-Norm if needed.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Apply RoPE if needed.
        if freqs_cis is not None:
            img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
            img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
            img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, upcast=upcast_rope)
            q = torch.cat((img_q, txt_q), dim=1)
            k = torch.cat((img_k, txt_k), dim=1)

        if is_enhance_enabled_single():
            # The image tokens are the leading slice of q/k (identical to img_q/img_k
            # concatenated above when RoPE is applied); slicing here avoids a
            # NameError when freqs_cis is None.
            feta_scores = get_feta_scores(q[:, :-txt_len], k[:, :-txt_len])

        # Compute attention.
        #assert (
        #    cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
        #), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
        if stg_mode is not None:
            if stg_mode == "STG-A":
                attn = attention(
                    q,
                    k,
                    v,
                    heads = self.heads_num,
                    mode=self.attention_mode,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
                    batch_size=x.shape[0],
                    do_stg=True,
                    txt_len=txt_len,
                    attn_mask=attn_mask
                )
                output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
                return x + apply_gate(output, gate=mod_gate)
            elif stg_mode == "STG-R":
                attn = attention(
                    q,
                    k,
                    v,
                    heads = self.heads_num,
                    mode=self.attention_mode,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_kv=cu_seqlens_kv,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_kv=max_seqlen_kv,
                    batch_size=x.shape[0],
                    attn_mask=attn_mask
                )
                # Compute activation in mlp stream, cat again and run second linear layer.
                output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
                output = apply_gate(output, gate=mod_gate)
                # STG-R: keep the residual only for the last batch element (the
                # perturbed sample) and zero it for the rest.
                batch_size = output.shape[0]
                output[:batch_size - 1, :, :] = 0
                return x + output
            else:
                raise NotImplementedError(f"Unsupported stg_mode: {stg_mode}")
        else:
            attn = attention(
                q,
                k,
                v,
                heads = self.heads_num,
                mode=self.attention_mode,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=x.shape[0],
                attn_mask=attn_mask
            )
            if is_enhance_enabled_single():
                attn *= feta_scores
                #attn[:, :-txt_len, :] *= feta_scores
        
            # Compute activation in mlp stream, cat again and run second linear layer.
            output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
            if condition_type == "token_replace":
                output = x + apply_gate(output, gate=mod_gate, condition_type=condition_type,
                                        tr_gate=tr_mod_gate, first_frame_token_num=first_frame_token_num)
                return output
            else:
                return x + apply_gate(output, gate=mod_gate)


class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
    """
    HunyuanVideo Transformer backbone

    Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.

    Reference:
    [1] Flux.1: https://github.com/black-forest-labs/flux
    [2] MMDiT: http://arxiv.org/abs/2403.03206

    Parameters
    ----------
    patch_size: list
        The size of the patch.
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    heads_num: int
        The number of attention heads.
    mlp_width_ratio: float
        The ratio of the MLP hidden size to the transformer hidden size.
    mlp_act_type: str
        The activation function of the MLP in the transformer block.
    mm_double_blocks_depth: int
        The number of double-stream transformer blocks.
    mm_single_blocks_depth: int
        The number of single-stream transformer blocks.
    rope_dim_list: list
        The dimension of the rotary embedding for t, h, w.
    qkv_bias: bool
        Whether to use bias in the qkv linear layer.
    qk_norm: bool
        Whether to use qk norm.
    qk_norm_type: str
        The type of qk norm.
    guidance_embed: bool
        Whether to use guidance embedding for distillation.
    text_projection: str
        The type of the text projection, default is single_refiner.
    use_attention_mask: bool
        Whether to use attention mask for text encoder.
    dtype: torch.dtype
        The dtype of the model.
    device: torch.device
        The device of the model.
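
    Example (illustrative values; real checkpoints supply these via their config):
        >>> model = HYVideoDiffusionTransformer(
        ...     in_channels=16,
        ...     hidden_size=3072,
        ...     heads_num=24,
        ...     mm_double_blocks_depth=20,
        ...     mm_single_blocks_depth=40,
        ...     rope_dim_list=[16, 56, 56],
        ...     dtype=torch.bfloat16,
        ... )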
    """

    @register_to_config
    def __init__(
        self,
        patch_size: list = [1, 2, 2],
        in_channels: int = 4,  # Should be VAE.config.latent_channels.
        out_channels: int = None,
        hidden_size: int = 3072,
        heads_num: int = 24,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        mm_double_blocks_depth: int = 20,
        mm_single_blocks_depth: int = 40,
        rope_dim_list: List[int] = [16, 56, 56],
        qkv_bias: bool = True,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        guidance_embed: bool = False,  # For modulation.
        text_projection: str = "single_refiner",
        use_attention_mask: bool = True,
        text_states_dim: int = 4096,
        text_states_dim_2: int = 768,
        i2v_condition_type: str = "latent_concat",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        main_device: Optional[torch.device] = None,
        offload_device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.unpatchify_channels = self.out_channels
        self.guidance_embed = guidance_embed
        self.rope_dim_list = rope_dim_list

        self.main_device = main_device
        self.offload_device = offload_device
        self.attention_mode = attention_mode
        self.i2v_condition_type = i2v_condition_type

        # Text projection. Default to linear projection.
        # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
        self.use_attention_mask = use_attention_mask
        self.text_projection = text_projection

        self.text_states_dim = text_states_dim
        self.text_states_dim_2 = text_states_dim_2

        if hidden_size % heads_num != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
            )
        pe_dim = hidden_size // heads_num
        if sum(rope_dim_list) != pe_dim:
            raise ValueError(
                f"Got rope_dim_list {rope_dim_list}, but its sum must equal the positional dim {pe_dim}"
            )
        self.hidden_size = hidden_size
        self.heads_num = heads_num

        # image projection
        self.img_in = PatchEmbed(
            self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
        )

        # text projection
        if self.text_projection == "linear":
            self.txt_in = TextProjection(
                self.text_states_dim,
                self.hidden_size,
                get_activation_layer("silu"),
                **factory_kwargs,
            )
        elif self.text_projection == "single_refiner":
            self.txt_in = SingleTokenRefiner(
                self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
            )
        else:
            raise NotImplementedError(
                f"Unsupported text_projection: {self.text_projection}"
            )

        # time modulation
        self.time_in = TimestepEmbedder(
            self.hidden_size, get_activation_layer("silu"), **factory_kwargs
        )

        # text modulation
        self.vector_in = MLPEmbedder(
            self.text_states_dim_2, self.hidden_size, **factory_kwargs
        )

        # guidance modulation
        self.guidance_in = (
            TimestepEmbedder(
                self.hidden_size, get_activation_layer("silu"), **factory_kwargs
            )
            if guidance_embed
            else None
        )

        # double blocks
        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    attention_mode=attention_mode,
                    **factory_kwargs,
                )
                for _ in range(mm_double_blocks_depth)
            ]
        )

        # single blocks
        self.single_blocks = nn.ModuleList(
            [
                MMSingleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    attention_mode=attention_mode,
                    **factory_kwargs,
                )
                for _ in range(mm_single_blocks_depth)
            ]
        )

        self.final_layer = FinalLayer(
            self.hidden_size,
            self.patch_size,
            self.out_channels,
            get_activation_layer("silu"),
            **factory_kwargs,
        )

        self.upcast_rope = True
        
        # Init block-swap variables.
        self.double_blocks_to_swap = -1
        self.single_blocks_to_swap = -1
        self.offload_txt_in = False
        self.offload_img_in = False

        # Init TeaCache variables.
        self.enable_teacache = False
        self.cnt = 0
        self.num_steps = 0
        self.teacache_skipped_steps = 0
        self.rel_l1_thresh = 0.15
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.last_dims = None
        self.last_frame_count = None
        self.teacache_device = None
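
        # A minimal caller-side sketch of enabling TeaCache (assumed wiring; the
        # sampling loop itself lives outside this module):
        #     model.enable_teacache = True
        #     model.num_steps = num_inference_steps  # hypothetical variable
        #     model.rel_l1_thresh = 0.15             # larger -> more skipped steps
        #     model.teacache_device = torch.device("cpu")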

    # thanks @2kpr for the initial block swap code!
    def block_swap(self, double_blocks_to_swap, single_blocks_to_swap, offload_txt_in=False, offload_img_in=False):
        print(f"Swapping {double_blocks_to_swap + 1} double blocks and {single_blocks_to_swap + 1} single blocks")
        self.double_blocks_to_swap = double_blocks_to_swap
        self.single_blocks_to_swap = single_blocks_to_swap
        self.offload_txt_in = offload_txt_in
        self.offload_img_in = offload_img_in
        for b, block in enumerate(self.double_blocks):
            if b > self.double_blocks_to_swap:
                #print(f"Moving double_block {b} to main device")
                block.to(self.main_device)
            else:
                #print(f"Moving double_block {b} to offload_device")
                block.to(self.offload_device)
        for b, block in enumerate(self.single_blocks):
            if b > self.single_blocks_to_swap:
                block.to(self.main_device)
            else:
                block.to(self.offload_device)

    def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
        def cast_to(weight, dtype=None, device=None, copy=False):
            if device is None or weight.device == device:
                if not copy:
                    if dtype is None or weight.dtype == dtype:
                        return weight
                return weight.to(dtype=dtype, copy=copy)

            r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight)
            return r

        def cast_weight(s, input=None, dtype=None, device=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            return weight

        def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if bias_dtype is None:
                    bias_dtype = dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
            return weight, bias

        class quantized_layer:
            class Linear(torch.nn.Linear):
                def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
                    super().__init__(*args, **kwargs)
                    self.dtype = dtype
                    self.device = device

                def block_forward_(self, x, i, j, dtype, device):
                    weight_ = cast_to(
                        self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
                        dtype=dtype, device=device
                    )
                    if self.bias is None or i > 0:
                        bias_ = None
                    else:
                        bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
                    x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
                    y_ = torch.nn.functional.linear(x_, weight_, bias_)
                    del x_, weight_, bias_
                    torch.cuda.empty_cache()
                    return y_
                
                def block_forward(self, x, **kwargs):
                    # Block-wise fallback: saves only ~2 GB of VRAM, so it is unused by
                    # default. Note it assumes `self.block_size` has been set by the caller.
                    y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
                    for i in range((self.in_features + self.block_size - 1) // self.block_size):
                        for j in range((self.out_features + self.block_size - 1) // self.block_size):
                            y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
                    return y
                    
                def forward(self, x, **kwargs):
                    weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
                    return torch.nn.functional.linear(x, weight, bias)

            
            class RMSNorm(torch.nn.Module):
                def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
                    super().__init__()
                    self.module = module
                    self.dtype = dtype
                    self.device = device
                    
                def forward(self, hidden_states, **kwargs):
                    input_dtype = hidden_states.dtype
                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                    hidden_states = hidden_states.to(input_dtype)
                    if self.module.weight is not None:
                        weight = cast_weight(self.module, hidden_states, dtype=self.dtype, device=self.device)
                        hidden_states = hidden_states * weight
                    return hidden_states
                
            class Conv3d(torch.nn.Conv3d):
                def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
                    super().__init__(*args, **kwargs)
                    self.dtype = dtype
                    self.device = device
                    
                def forward(self, x):
                    weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
                    return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
                
            class LayerNorm(torch.nn.LayerNorm):
                def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
                    super().__init__(*args, **kwargs)
                    self.dtype = dtype
                    self.device = device
                    
                def forward(self, x):
                    if self.weight is not None and self.bias is not None:
                        weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
                        return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
                    else:
                        return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

        def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
            for name, module in model.named_children():
                if isinstance(module, torch.nn.Linear):
                    with init_weights_on_device():
                        new_layer = quantized_layer.Linear(
                            module.in_features, module.out_features, bias=module.bias is not None,
                            dtype=dtype, device=device
                        )
                    new_layer.load_state_dict(module.state_dict(), assign=True)
                    setattr(model, name, new_layer)
                elif isinstance(module, torch.nn.Conv3d):
                    with init_weights_on_device():
                        new_layer = quantized_layer.Conv3d(
                            module.in_channels, module.out_channels, kernel_size=module.kernel_size,
                            stride=module.stride, padding=module.padding, dilation=module.dilation,
                            groups=module.groups, bias=module.bias is not None,
                            dtype=dtype, device=device
                        )
                    new_layer.load_state_dict(module.state_dict(), assign=True)
                    setattr(model, name, new_layer)
                elif isinstance(module, RMSNorm):
                    new_layer = quantized_layer.RMSNorm(
                        module, dtype=dtype, device=device
                    )
                    setattr(model, name, new_layer)
                elif isinstance(module, torch.nn.LayerNorm):
                    with init_weights_on_device():
                        new_layer = quantized_layer.LayerNorm(
                            module.normalized_shape, elementwise_affine=module.elementwise_affine,
                            eps=module.eps, dtype=dtype, device=device
                        )
                    new_layer.load_state_dict(module.state_dict(), assign=True)
                    setattr(model, name, new_layer)
                else:
                    replace_layer(module, dtype=dtype, device=device)

        replace_layer(self, dtype=dtype, device=device)
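
    # A minimal usage sketch (assumed workflow): keep weights in low precision and
    # cast them onto the compute device layer-by-layer during forward:
    #     model.enable_auto_offload(dtype=torch.bfloat16, device="cuda")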

    def enable_deterministic(self):
        for block in self.double_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        for block in self.double_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Used by the token refiner when use_attention_mask is set.
        text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
        stg_mode: str = None,
        stg_block_idx: int = -1,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        
        def _process_double_blocks(img, txt, vec, block_args):
            for b, block in enumerate(self.double_blocks):
                if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0:
                    block.to(self.main_device)
                    
                img, txt = block(img, txt, vec, *block_args)
                
                if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0:
                    block.to(self.offload_device, non_blocking=True)
            return img, txt

        def _process_single_blocks(x, vec, txt_seq_len, block_args, stg_mode=None, stg_block_idx=None):
            for b, block in enumerate(self.single_blocks):
                if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0:
                    block.to(self.main_device)
                    
                curr_stg_mode = stg_mode if b == stg_block_idx else None
                x = block(x, vec, txt_seq_len, *block_args, curr_stg_mode)
                
                if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0:
                    block.to(self.offload_device, non_blocking=True)
            return x
        
        out = {}
        img = x
        txt = text_states
        _, _, ot, oh, ow = x.shape
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )
        set_num_frames(img.shape[2])

        current_dims = (ot, oh, ow)

        # Reset TeaCache state when the latent dimensions change between runs.
        if self.last_dims != current_dims:
            self.cnt = 0
            self.accumulated_rel_l1_distance = 0
            self.previous_modulated_input = None
            self.previous_residual = None
            self.last_dims = current_dims

        # Prepare modulation vectors.
        vec = self.time_in(t)

        if self.i2v_condition_type == "token_replace":
            token_replace_t = torch.zeros_like(t)
            token_replace_vec = self.time_in(token_replace_t)
            first_frame_token_num = th * tw
        else:
            token_replace_vec = None
            first_frame_token_num = None

        # text modulation
        if text_states_2 is not None:
            vec_2 = self.vector_in(text_states_2)
            vec = vec + vec_2
            if self.i2v_condition_type == "token_replace":
                token_replace_vec = token_replace_vec + vec_2


        # guidance modulation
        if guidance is not None:
            # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
            vec = vec + self.guidance_in(guidance)

        # Embed image and text.
        if self.offload_txt_in:
            self.txt_in.to(self.main_device)
        if self.offload_img_in:
            self.img_in.to(self.main_device)

        img = self.img_in(img)
        if self.text_projection == "linear":
            txt = self.txt_in(txt)
        elif self.text_projection == "single_refiner":
            txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
        else:
            raise NotImplementedError(
                f"Unsupported text_projection: {self.text_projection}"
            )
        if self.offload_txt_in:
            self.txt_in.to(self.offload_device, non_blocking=True)
        if self.offload_img_in:
            self.img_in.to(self.offload_device, non_blocking=True)

        max_seqlen_q, max_seqlen_kv, attn_mask, cu_seqlens_q, cu_seqlens_kv = None, None, None, None, None
        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        if "varlen" in self.attention_mode: #just for backwards compatibility
            max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len
            text_mask = torch.ones((1, text_states.shape[1]), dtype=torch.bool, device=text_states.device)
            # Compute cu_squlens for flash attention
            cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
            cu_seqlens_kv = cu_seqlens_q

        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

        block_args = [
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            freqs_cis,
            attn_mask,
            self.upcast_rope,
            token_replace_vec,
            first_frame_token_num,
            self.i2v_condition_type,
        ]

        # TeaCache: optionally reuse the previous residual when the modulated
        # input has changed little between steps.
        if self.enable_teacache:
            inp = img.clone()
            vec_ = vec.clone()
            txt_ = txt.clone()
            self.double_blocks[0].to(self.main_device)
            (
                img_mod1_shift,
                img_mod1_scale,
                img_mod1_gate,
                img_mod2_shift,
                img_mod2_scale,
                img_mod2_gate,
            ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
            normed_inp = self.double_blocks[0].img_norm1(inp)
            modulated_inp = modulate(
                normed_inp, shift=img_mod1_shift, scale=img_mod1_scale
            )

            if self.cnt == 0 or self.cnt == self.num_steps - 1:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
            else:
                # Rescale the relative L1 change of the modulated input with an
                # empirically fitted polynomial and accumulate it across steps.
                coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
                rescale_func = np.poly1d(coefficients)
                rel_l1 = (
                    (modulated_inp - self.previous_modulated_input).abs().mean()
                    / self.previous_modulated_input.abs().mean()
                ).cpu().item()
                self.accumulated_rel_l1_distance += rescale_func(rel_l1)
                if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.accumulated_rel_l1_distance = 0
            self.previous_modulated_input = modulated_inp.clone()
            self.cnt += 1
            if self.cnt == self.num_steps:
                self.cnt = 0

            if not should_calc:
                # Reuse the cached residual only if its dimensions still match.
                if self.previous_residual is not None and img.shape == self.previous_residual.shape:
                    self.teacache_skipped_steps += 1
                    img = img + self.previous_residual.to(img.device)
                else:
                    should_calc = True  # No compatible cached residual; force recalculation.

            if should_calc:
                ori_img = img.clone()
                # Pass through DiT blocks
                img, txt = _process_double_blocks(img, txt, vec, block_args)
                # Merge txt and img to pass through single stream blocks.
                x = torch.cat((img, txt), 1)
                x = _process_single_blocks(x, vec, txt.shape[1], block_args, stg_mode, stg_block_idx)

                img = x[:, :img_seq_len, ...]
                self.previous_residual = (img - ori_img).to(self.teacache_device)
        else:
            # Pass through DiT blocks
            img, txt = _process_double_blocks(img, txt, vec, block_args)
            # Merge txt and img to pass through single stream blocks.
            x = torch.cat((img, txt), 1)
            x = _process_single_blocks(x, vec, txt.shape[1], block_args, stg_mode, stg_block_idx)
            img = x[:, :img_seq_len, ...]

        # ---------------------------- Final layer ------------------------------
        img = self.final_layer(img, vec)  # (N, T, prod(patch_size) * out_channels)

        img = self.unpatchify(img, tt, th, tw)
        if return_dict:
            out["x"] = img
            return out
        return img

    def unpatchify(self, x, t, h, w):
        """
        x: (N, T, pt * ph * pw * C) with T = t * h * w
        imgs: (N, C, t * pt, h * ph, w * pw)
        """
        c = self.unpatchify_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
        x = torch.einsum("nthwcopq->nctohpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

        return imgs
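
    # Shape sketch (illustrative numbers): with patch_size=[1, 2, 2] and
    # out_channels=16, x (N, t*h*w, 64) -> imgs (N, 16, t, 2*h, 2*w).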

#################################################################################
#                             HunyuanVideo Configs                              #
#################################################################################

# HUNYUAN_VIDEO_CONFIG = {
#     "HYVideo-T/2": {
#         "mm_double_blocks_depth": 20,
#         "mm_single_blocks_depth": 40,
#         "rope_dim_list": [16, 56, 56],
#         "hidden_size": 3072,
#         "heads_num": 24,
#         "mlp_width_ratio": 4,
#     },
#     "HYVideo-T/2-cfgdistill": {
#         "mm_double_blocks_depth": 20,
#         "mm_single_blocks_depth": 40,
#         "rope_dim_list": [16, 56, 56],
#         "hidden_size": 3072,
#         "heads_num": 24,
#         "mlp_width_ratio": 4,
#         "guidance_embed": True,
#     },
# }
